{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Camera"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from jetcam.csi_camera import CSICamera\n",
    "# from jetcam.usb_camera import USBCamera\n",
    "\n",
    "camera = CSICamera(width=224, height=224)\n",
    "# camera = USBCamera(width=224, height=224)\n",
    "\n",
    "camera.running = True"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision.transforms as transforms\n",
    "from xy_dataset import XYDataset\n",
    "\n",
    "TASK = 'road_following'\n",
    "\n",
    "CATEGORIES = ['apex']\n",
    "\n",
    "DATASETS = ['A', 'B']\n",
    "\n",
    "TRANSFORMS = transforms.Compose([\n",
    "    transforms.ColorJitter(0.2, 0.2, 0.2, 0.2),\n",
    "    transforms.Resize((224, 224)),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])\n",
    "])\n",
    "\n",
    "datasets = {}\n",
    "for name in DATASETS:\n",
    "    datasets[name] = XYDataset(TASK + '_' + name, CATEGORIES, TRANSFORMS, random_hflip=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data Collection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cv2\n",
    "import ipywidgets\n",
    "import traitlets\n",
    "from IPython.display import display\n",
    "from jetcam.utils import bgr8_to_jpeg\n",
    "from jupyter_clickable_image_widget import ClickableImageWidget\n",
    "\n",
    "\n",
    "# initialize active dataset\n",
    "dataset = datasets[DATASETS[0]]\n",
    "\n",
    "# unobserve all callbacks from camera in case we are running this cell for second time\n",
    "camera.unobserve_all()\n",
    "\n",
    "# create image preview\n",
    "camera_widget = ClickableImageWidget(width=camera.width, height=camera.height)\n",
    "snapshot_widget = ipywidgets.Image(width=camera.width, height=camera.height)\n",
    "traitlets.dlink((camera, 'value'), (camera_widget, 'value'), transform=bgr8_to_jpeg)\n",
    "\n",
    "# create widgets\n",
    "dataset_widget = ipywidgets.Dropdown(options=DATASETS, description='dataset')\n",
    "category_widget = ipywidgets.Dropdown(options=dataset.categories, description='category')\n",
    "count_widget = ipywidgets.IntText(description='count')\n",
    "\n",
    "# manually update counts at initialization\n",
    "count_widget.value = dataset.get_count(category_widget.value)\n",
    "\n",
    "# sets the active dataset\n",
    "def set_dataset(change):\n",
    "    global dataset\n",
    "    dataset = datasets[change['new']]\n",
    "    count_widget.value = dataset.get_count(category_widget.value)\n",
    "dataset_widget.observe(set_dataset, names='value')\n",
    "\n",
    "# update counts when we select a new category\n",
    "def update_counts(change):\n",
    "    count_widget.value = dataset.get_count(change['new'])\n",
    "category_widget.observe(update_counts, names='value')\n",
    "\n",
    "\n",
    "def save_snapshot(_, content, msg):\n",
    "    if content['event'] == 'click':\n",
    "        data = content['eventData']\n",
    "        x = data['offsetX']\n",
    "        y = data['offsetY']\n",
    "        \n",
    "        # save to disk\n",
    "        dataset.save_entry(category_widget.value, camera.value, x, y)\n",
    "        \n",
    "        # display saved snapshot\n",
    "        snapshot = camera.value.copy()\n",
    "        snapshot = cv2.circle(snapshot, (x, y), 8, (0, 255, 0), 3)\n",
    "        snapshot_widget.value = bgr8_to_jpeg(snapshot)\n",
    "        count_widget.value = dataset.get_count(category_widget.value)\n",
    "        \n",
    "camera_widget.on_msg(save_snapshot)\n",
    "\n",
    "data_collection_widget = ipywidgets.VBox([\n",
    "    ipywidgets.HBox([camera_widget, snapshot_widget]),\n",
    "    dataset_widget,\n",
    "    category_widget,\n",
    "    count_widget\n",
    "])\n",
    "\n",
    "display(data_collection_widget)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "\n",
    "device = torch.device('cuda')\n",
    "output_dim = 2 * len(dataset.categories)  # x, y coordinate for each category\n",
    "\n",
    "# ALEXNET\n",
    "# model = torchvision.models.alexnet(pretrained=True)\n",
    "# model.classifier[-1] = torch.nn.Linear(4096, output_dim)\n",
    "\n",
    "# SQUEEZENET \n",
    "# model = torchvision.models.squeezenet1_1(pretrained=True)\n",
    "# model.classifier[1] = torch.nn.Conv2d(512, output_dim, kernel_size=1)\n",
    "# model.num_classes = len(dataset.categories)\n",
    "\n",
    "# RESNET 18\n",
    "model = torchvision.models.resnet18(pretrained=True)\n",
    "model.fc = torch.nn.Linear(512, output_dim)\n",
    "\n",
    "# RESNET 34\n",
    "# model = torchvision.models.resnet34(pretrained=True)\n",
    "# model.fc = torch.nn.Linear(512, output_dim)\n",
    "\n",
    "# DENSENET 121\n",
    "# model = torchvision.models.densenet121(pretrained=True)\n",
    "# model.classifier = torch.nn.Linear(model.num_features, output_dim)\n",
    "\n",
    "model = model.to(device)\n",
    "\n",
    "model_save_button = ipywidgets.Button(description='save model')\n",
    "model_load_button = ipywidgets.Button(description='load model')\n",
    "model_path_widget = ipywidgets.Text(description='model path', value='road_following_model.pth')\n",
    "\n",
    "def load_model(c):\n",
    "    model.load_state_dict(torch.load(model_path_widget.value))\n",
    "model_load_button.on_click(load_model)\n",
    "    \n",
    "def save_model(c):\n",
    "    torch.save(model.state_dict(), model_path_widget.value)\n",
    "model_save_button.on_click(save_model)\n",
    "\n",
    "model_widget = ipywidgets.VBox([\n",
    "    model_path_widget,\n",
    "    ipywidgets.HBox([model_load_button, model_save_button])\n",
    "])\n",
    "\n",
    "\n",
    "display(model_widget)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Live Execution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import threading\n",
    "import time\n",
    "from utils import preprocess\n",
    "import torch.nn.functional as F\n",
    "\n",
    "state_widget = ipywidgets.ToggleButtons(options=['stop', 'live'], description='state', value='stop')\n",
    "prediction_widget = ipywidgets.Image(format='jpeg', width=camera.width, height=camera.height)\n",
    "\n",
    "def live(state_widget, model, camera, prediction_widget):\n",
    "    global dataset\n",
    "    while state_widget.value == 'live':\n",
    "        image = camera.value\n",
    "        preprocessed = preprocess(image)\n",
    "        output = model(preprocessed).detach().cpu().numpy().flatten()\n",
    "        category_index = dataset.categories.index(category_widget.value)\n",
    "        x = output[2 * category_index]\n",
    "        y = output[2 * category_index + 1]\n",
    "        \n",
    "        x = int(camera.width * (x / 2.0 + 0.5))\n",
    "        y = int(camera.height * (y / 2.0 + 0.5))\n",
    "        \n",
    "        prediction = image.copy()\n",
    "        prediction = cv2.circle(prediction, (x, y), 8, (255, 0, 0), 3)\n",
    "        prediction_widget.value = bgr8_to_jpeg(prediction)\n",
    "            \n",
    "def start_live(change):\n",
    "    if change['new'] == 'live':\n",
    "        execute_thread = threading.Thread(target=live, args=(state_widget, model, camera, prediction_widget))\n",
    "        execute_thread.start()\n",
    "\n",
    "state_widget.observe(start_live, names='value')\n",
    "\n",
    "live_execution_widget = ipywidgets.VBox([\n",
    "    prediction_widget,\n",
    "    state_widget\n",
    "])\n",
    "\n",
    "display(live_execution_widget)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "BATCH_SIZE = 8\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters())\n",
    "# optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)\n",
    "\n",
    "epochs_widget = ipywidgets.IntText(description='epochs', value=1)\n",
    "eval_button = ipywidgets.Button(description='evaluate')\n",
    "train_button = ipywidgets.Button(description='train')\n",
    "loss_widget = ipywidgets.FloatText(description='loss')\n",
    "progress_widget = ipywidgets.FloatProgress(min=0.0, max=1.0, description='progress')\n",
    "\n",
    "def train_eval(is_training):\n",
    "    global BATCH_SIZE, LEARNING_RATE, MOMENTUM, model, dataset, optimizer, eval_button, train_button, accuracy_widget, loss_widget, progress_widget, state_widget\n",
    "    \n",
    "    try:\n",
    "        train_loader = torch.utils.data.DataLoader(\n",
    "            dataset,\n",
    "            batch_size=BATCH_SIZE,\n",
    "            shuffle=True\n",
    "        )\n",
    "\n",
    "        state_widget.value = 'stop'\n",
    "        train_button.disabled = True\n",
    "        eval_button.disabled = True\n",
    "        time.sleep(1)\n",
    "\n",
    "        if is_training:\n",
    "            model = model.train()\n",
    "        else:\n",
    "            model = model.eval()\n",
    "\n",
    "        while epochs_widget.value > 0:\n",
    "            i = 0\n",
    "            sum_loss = 0.0\n",
    "            error_count = 0.0\n",
    "            for images, category_idx, xy in iter(train_loader):\n",
    "                # send data to device\n",
    "                images = images.to(device)\n",
    "                xy = xy.to(device)\n",
    "\n",
    "                if is_training:\n",
    "                    # zero gradients of parameters\n",
    "                    optimizer.zero_grad()\n",
    "\n",
    "                # execute model to get outputs\n",
    "                outputs = model(images)\n",
    "\n",
    "                # compute MSE loss over x, y coordinates for associated categories\n",
    "                loss = 0.0\n",
    "                for batch_idx, cat_idx in enumerate(list(category_idx.flatten())):\n",
    "                    loss += torch.mean((outputs[batch_idx][2 * cat_idx:2 * cat_idx+2] - xy[batch_idx])**2)\n",
    "                loss /= len(category_idx)\n",
    "\n",
    "                if is_training:\n",
    "                    # run backpropogation to accumulate gradients\n",
    "                    loss.backward()\n",
    "\n",
    "                    # step optimizer to adjust parameters\n",
    "                    optimizer.step()\n",
    "\n",
    "                # increment progress\n",
    "                count = len(category_idx.flatten())\n",
    "                i += count\n",
    "                sum_loss += float(loss)\n",
    "                progress_widget.value = i / len(dataset)\n",
    "                loss_widget.value = sum_loss / i\n",
    "                \n",
    "            if is_training:\n",
    "                epochs_widget.value = epochs_widget.value - 1\n",
    "            else:\n",
    "                break\n",
    "    except e:\n",
    "        pass\n",
    "    model = model.eval()\n",
    "\n",
    "    train_button.disabled = False\n",
    "    eval_button.disabled = False\n",
    "    state_widget.value = 'live'\n",
    "    \n",
    "train_button.on_click(lambda c: train_eval(is_training=True))\n",
    "eval_button.on_click(lambda c: train_eval(is_training=False))\n",
    "    \n",
    "train_eval_widget = ipywidgets.VBox([\n",
    "    epochs_widget,\n",
    "    progress_widget,\n",
    "    loss_widget,\n",
    "    ipywidgets.HBox([train_button, eval_button])\n",
    "])\n",
    "\n",
    "display(train_eval_widget)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### All together!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The following widget can be used to label a multi-class x, y dataset.  It supports labeling only one instance of each class per image (ie: only one dog), but multiple classes (ie: dog, cat, horse) per image are possible.\n",
    "\n",
    "Click the image on the top left to save an image of ``category`` to ``dataset`` at the clicked location.\n",
    "\n",
    "| Widget | Description |\n",
    "|--------|-------------|\n",
    "| dataset | Selects the active dataset |\n",
    "| category | Selects the active category |\n",
    "| epochs | Sets the number of epochs to train for |\n",
    "| train | Trains on the active dataset for the number of epochs specified |\n",
    "| evaluate | Evaluates the accuracy on the active dataset over one epoch |\n",
    "| model path | Sets the active model path |\n",
    "| load | Loads a model from the active model path |\n",
    "| save | Saves a model to the active model path |\n",
    "| stop | Disables the live demo |\n",
    "| live | Enables the live demo |"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_widget = ipywidgets.VBox([\n",
    "    ipywidgets.HBox([data_collection_widget, live_execution_widget]), \n",
    "    train_eval_widget,\n",
    "    model_widget\n",
    "])\n",
    "\n",
    "display(all_widget)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
